#include "statsxx/machine_learning/RBM.hpp"

// STL
#include <algorithm> // std::min(), std::max()
#include <cmath>     // std::round(), std::fabs(), std::sqrt(), std::exp()
#include <fstream>   // std::ofstream
#include <iomanip>   // std::setw()
#include <iostream>  // std::cout
#include <limits>    // std::numeric_limits<>::max()
#include <numeric>   // std::iota()
#include <string>    // std::string
#include <tuple>     // std::tuple<>, std::make_tuple()
#include <vector>    // std::vector<>

// jScience
#include "jScience/linalg.hpp"       // Matrix<>, Vector<>, outer_product()
#include "jScience/math/utility.hpp" // ngroups()
#include "jScience/stats/data.hpp"   // shuffle_data()
#include "jScience/stl/vector.hpp"   // subvector()
#include "jrandnum.hpp"              // rand_num_normal_Mersenne_twister()
#include "jutility.hpp"              // get_n_rand_unique_elements()

// stats++
#include "statsxx/distribution.hpp"  // distribution::binomial()
#include "statsxx/machine_learning/activation_functions.hpp" // activation_function::Logistic, activation_function::ReLU, activation_function::softplus
#include "statsxx/statistics.hpp"    // statistics::mean()


// using namespace machine_learning;


//
// TODO: one thing that I had, but took out, was the printing of some reconstructions as time goes on (this shows the improvement of reconstruction as time goes on) ... but in more numeric terms this is shown in the reconstruction file
//
// TODO: does CD_n approach the correct maximum likelihood as n increases? I thought that I or Jeevs read this.
//
// =====
//
//
// TODO: NOTE: might be able hold out a validation set automatically, but as of now, a validation set is passed in so the user has more control
//
// TODO: figure out how to automatically determine when to stop training:
// --- NOTE: this should work without storing all the weights (like NN training does)
// --- maybe approximate curve by a rational function (since there are horizontal asymptotes) and find the inflection point?
// --- maybe define windows over which to average the results, and then stop based on some criteria?


//
// DESC: New RBM
//
// NOTE: the learning rate is adjusted automatically, by determining the cosine of the angle between the current gradient and that prior. If the cosine is large and positive (meaning that the gradients point in nearly the same direction), we increase the learning rate; and if it is large and negative (meaning that the gradients point in nearly opposite directions), we decrease the learning rate. Cutoffs and adjustments for this are lr_dot and lr_adj, respectively.
// NOTE: ... see: T. Masters, "Deep Belief Nets in C++ and CUDA C" pp. 138-139 and 159
//
// NOTE: two convergence criteria are implemented:
//     (i) the max magnitude gradient to max magnitude weight for a given epoch is below a threshold (convg_criterion_max_inc_max_w)
//     (ii) the max magnitude gradient to max magnitude weight fails to decrease for a successive number of steps (convg_criterion_nno_improvement)
// NOTE: ... see: T. Masters, "Deep Belief Nets in C++ and CUDA C" pp. 140--141 and 160
//
//
// NOTE:
//     - for some simple implementations of RBMs, see:
//          http://www.cs.toronto.edu/~hinton/code/rbm.m
//          http://www.cs.toronto.edu/~hinton/code/rbmhidlinear.m
//     for Binary--Binary and Binary--Gaussian RBMs, respectively
//
inline void machine_learning::RBM::train(
                                         const int                  nepoch,
                                         const int                  nbatches,
                                         // -----
                                         const int                  K_start,            // starting number of CD steps
                                         const int                  K_end,              // ending number
                                         const double               K_rate,             // rate at which K_start transitions to K_end
                                         // -----
                                         const bool                 sample_v,           // whether to sample v during reconstructions --- this is correct and may lead to better density models (Hinton, practical guide), but not reduces sampling noise thus allowing faster learning
                                         const bool                 sample_hdata,       // whether to sample h for the positive statistics --- true is closer to the mathematical model of an RBM
                                         // -----
                                         double                     lr,
                                         const std::vector<double> &lr_dot,
                                         const std::vector<double> &lr_adj,
                                         const double               lr_min,
                                         const double               lr_max,
                                         // ---
                                         const double               momentum_start,
                                         const double               momentum_end,
                                         const double               momentum_rate,
                                         const std::vector<double> &momentum_dot,
                                         const std::vector<double> &momentum_adj,
                                         // -----
                                         const double               weight_penalty,
                                         // -----
                                         const int                  free_energy_npts,   // number of training points to (randomly) select for free energy
                                         const int                  free_energy_nepoch, // evaluate the free energy every number of epochs
                                         const int                  free_energy_window, // calculate the slope of the free energy over this window
                                         const double               convg_criterion_max_inc_max_w,
                                         const int                  convg_criterion_nno_improvement,
                                         // -----
                                         const Matrix<double>      &X,                  // training set
                                         const Matrix<double>      &X_val,              // validation set
                                         // -----
                                         const std::string          prefix
                                         )
{
    //=========================================================
    // INITIALIZE
    //=========================================================
    
    // --------------------------------------------------------
    // WEIGHTS AND BIASES
    // --------------------------------------------------------
    
    W = Matrix<double>(nv,nh);
    for(auto i = 0; i < nv; ++i)
    {
        for(auto j = 0; j < nh; ++j)
        {
            W(i,j) = 0.1*rand_num_normal_Mersenne_twister(0.,1.);
        }
    }
    
    a = Vector<double>(nv, 0.);
    b = Vector<double>(nh, 0.);
    
    // ----- GRADIENTS -----
    
    Matrix<double> W_inc(nv,nh, 0.0);
    Vector<double> a_inc(nv, 0.0);
    Vector<double> b_inc(nh, 0.0);
    
    // ----- ACTIVATION FUNCTIONS -----
    
    activation_function::Logistic logistic;
    activation_function::ReLU     ReLU;
    activation_function::softplus softplus;
    
    // --------------------------------------------------------
    // TRAINING
    // --------------------------------------------------------
    
    // ----- BATCHES -----
    
    std::vector<int> shuffle_idx(X.size(0));
    std::iota(shuffle_idx.begin(), shuffle_idx.end(), 0);
    
    std::vector<int> batch_begin;
    std::vector<int> batch_end;
    std::vector<int> batch_size;
    std::tie(
             batch_begin,
             batch_end,
             batch_size
             ) = ngroups(
                         X.size(0),
                         nbatches
                         );
    
    // ----- CONTRASTIVE DIVERGENCE -----
    
    int K = K_start;
    
    double one_m_K_rate = 1. - K_rate;
    
    // ----- LEARNING RATE AND MOMENTUM -----
    
    double momentum = momentum_start;
    
    double one_m_momentum_rate = 1. - momentum_rate;
    
    Matrix<double> W_inc_prev;
    double len_prev;
    
    // ----- CONVERGENCE CRITERIA -----
    
    std::vector<int> free_energy_idx = get_n_rand_unique_elements(
                                                                  free_energy_npts,
                                                                  shuffle_idx
                                                                  );
    
    // NOTE: as discussed in Hinton "A Practical Guide to Training Restricted Boltzmann Machines", it is important to use the SAME subset of training data to calculate relative free energies
    
    Matrix<double> X_free_energy(free_energy_npts,X.size(1));
    
    for(auto i = 0; i < free_energy_npts; ++i)
    {
        for(auto j = 0; j < X.size(1); ++j)
        {
            X_free_energy(i,j) = X(free_energy_idx[i],j);
        }
    }
    
    double convg_criterion_min_max_inc_max_w = std::numeric_limits<double>::max();
    int nno_improvement = 0;
    
    // --------------------------------------------------------
    // OUTPUT
    // --------------------------------------------------------
    
    std::ofstream ofs_free_energy((prefix + ".free_energy.dat"),std::ios::out);
    std::ofstream ofs_reconstruction((prefix + ".reconstruction.dat"),std::ios::out);
    
    //=========================================================
    // TRAIN
    //=========================================================
    
    for(int epoch = 0; epoch < nepoch; ++epoch)
    {
        shuffle_data(shuffle_idx);
        
        double max_w_inc = std::numeric_limits<double>::min();
        
        for(int batch = 0; batch < nbatches; ++batch)
        {
            Matrix<double> W_inc_batch(nv,nh, 0.0);
            Vector<double> a_inc_batch(nv, 0.0);
            Vector<double> b_inc_batch(nh, 0.0);
            
            for(int i = batch_begin[batch]; i <= batch_end[batch]; ++i)
            {
                // definitions:
                //
                // - v_data  : visible state from data distribution
                //
                // - ph_data  --- q_data : "probability" that each hidden neuron will be one, under the data distribution
                //
                // - h_data  : hidden state from data distribution
                //
                // ---
                //
                // - ph_model --- q_model: "probability" that each hidden neuron will be one, under the model distribution
                //
                // - h_model  --- h_model: hidden neuron activations, under the model distribution
                //
                // - pv_model --- p_model: reconstruction "probability" under the model distribution
                //
                // - v_model  --- v_model: reconstructed visible neuron activations, under the model distribution
                
                Vector<double> v_data = X.row(shuffle_idx[i]);
                
                // ----- START POSITIVE PHASE -----
                
                // NOTE: the positive phase collects the statistics <vi*hj>_data
                
                Vector<double> xh = transpose(this->W)*v_data + this->b;

                Vector<double> ph_data = this->calculate_p(
                                                           xh,
                                                           this->htype
                                                           );
                
                // NOTE: h_data is set below
                Vector<double> h_data;

                // ----- END POSITIVE PHASE -----
                
                // ----- START NEGATIVE PHASE -----
                
                // NOTE: the negative phase collects the statistics <vi*hj>_model
                
                Vector<double> v_model;
                Vector<double> ph_model = ph_data;
                
                for(int k = 0; k < K; ++k)
                {
                    // ----- ----- START SAMPLING ----- ----- 
                    
                    Vector<double> h_model = this->sample(
                                                          xh,
                                                          ph_model,
                                                          this->htype
                                                          );
                    
                    // ----- ----- END SAMPLING ----- ----- 

                    // NOTE: the following isn't actually part of the negative phase, but we deferred setting h_data above (because it may be set to h_model)
                    
                    if(k == 0)
                    {
                        if(sample_hdata)
                        {
                            // NOTE: Footnote 2 on p. 6 of Hinton's "A Practical Guide to Training Restricted Boltzmann Machines" makes it clear that if h_data is to be sampled (from ph_data) that it is the same as that used for the reconstruction (which is how there can be less noise in the difference of the positive and negative statistics --- which is referring to CD_1) 
                            h_data = h_model;
                        }
                        else
                        {
                            h_data = ph_data;
                        }
                    }

                    // ----- ----- START RECONSTRUCTION ----- ----- 
                    
                    Vector<double> xv = this->W*h_model + this->a;
                    
                    Vector<double> pv_model = this->calculate_p(
                                                                xv,
                                                                this->vtype
                                                                );
                    
                    if(sample_v)
                    {
                        v_model = this->sample(
                                               xv,
                                               pv_model,
                                               this->vtype
                                               );
                    }
                    else
                    {
                        v_model = pv_model;
                    }
                    
                    // ----- ----- END RECONSTRUCTION ----- -----
                    
                    // ----- ----- START CALCULATION OF HIDDEN STATE ----- ----- 
                    
                    xh = transpose(this->W)*v_model + this->b;
                    
                    ph_model = this->calculate_p(
                                                 xh,
                                                 this->htype
                                                 );
                    
                    // ----- ----- END CALCULATION OF HIDDEN STATE ----- ----- 
                }
                
                // ----- END NEGATIVE PHASE -----
                
                // ----- UPDATE GRADIENTS -----
                                
                // Update rules for the parameters:
                //     delta W_ij = <vi*hj>_data - <vi*hj>_model
                //     delta a_i  = <vi>_data - <vi>_model
                //     delta b_j  = <hj>_data - <hj>_model
                
                // NOTE: in the following, h_data may actually be ph_data (if sample_hdata is false)
                
                // NOTE: for the last update of the hidden units, h does not need to be sampled (because nothing depends on which state is chosen) --- this is discussed on p. 5 of Hinton's "A Practical Guide to Training Restricted Boltzmann Machines" --- (hence the use of ph_model)

                W_inc_batch += (outer_product(v_data, h_data) - outer_product(v_model, ph_model));
                a_inc_batch += (v_data - v_model);
                b_inc_batch += (h_data - ph_model);
            } // ++i
            
            // NOTE: batch gradients are always divided by the number of training cases, so that the learning rate multiples the average, per-case gradient
            W_inc_batch /= X.size(0);
            a_inc_batch /= X.size(0);
            b_inc_batch /= X.size(0);
            
            // ----- UPDATE WEIGHTS AND BIASES -----
            
            W_inc = momentum*W_inc + lr*(W_inc_batch - weight_penalty*W);
            a_inc = momentum*a_inc + lr*a_inc_batch;
            b_inc = momentum*b_inc + lr*b_inc_batch;
            
            W += W_inc;
            a += a_inc;
            b += b_inc;
            
            // ----- CONVERGENCE CRITERION -----
            // NOTE: every batch is updated so that we can find the greatest weight gradient for the entire epoch
            for(auto i = 0; i < W_inc.size(0); ++i)
            {
                for(auto j = 0; j < W_inc.size(1); ++j)
                {
                    double w_inc = std::fabs(W_inc(i,j));
                    
                    if(w_inc > max_w_inc)
                    {
                        max_w_inc = w_inc;
                    }
                }
            }
            
            // --------------------------------------------------------
            // ADJUST LEARNING RATE (AND MOMENTUM <-- NOT YET)
            // --------------------------------------------------------

            // NOTE: as shown in T. Masters, "Deep Belief Nets in C++ and CUDA C" on pp. 158--159, this update occurs inside the batch loop
            
            if((epoch == 0) && (batch == 0))
            {
                W_inc_prev = W_inc;
                
                len_prev = 0.;
                
                for(auto i = 0; i < W_inc_prev.size(0); ++i)
                {
                    for(auto j = 0; j < W_inc_prev.size(1); ++j)
                    {
                        len_prev += (W_inc_prev(i,j)*W_inc_prev(i,j));
                    }
                }
            }
            else
            {
                double dot = 0.;
                
                double len = 0.;
                
                for(auto i = 0; i < W_inc.size(0); ++i)
                {
                    for(auto j = 0; j < W_inc.size(1); ++j)
                    {
                        len += (W_inc(i,j)*W_inc(i,j));
                        
                        dot += (W_inc(i,j)*W_inc_prev(i,j));
                    }
                }
                
                // cosine of the angle between the current gradient and that prior
                dot /= std::sqrt(len*len_prev);
                
                // ----- ADJUST LEARNING RATE -----
                for(auto i = 0; i < lr_dot.size(); ++i)
                {
                    if(dot > lr_dot[i])
                    {
                        lr *= lr_adj[i];
                        
                        break;
                    }
                    else if(dot < -lr_dot[i])
                    {
                        lr /= lr_adj[i];
                        
                        break;
                    }
                }
                
                lr = std::min(lr, lr_max);
                lr = std::max(lr, lr_min);
                
                // ----- ADJUST MOMENTUM -----
                for(auto i = 0; i < momentum_dot.size(); ++i)
                {
                    if(std::fabs(dot) > momentum_dot[i])
                    {
                        momentum /= momentum_adj[i];
                        
                        break;
                    }
                    // NOTE: it is a source of instability to increase the momentum if the angle is small
                }
                
                W_inc_prev = W_inc;
                
                len_prev = len;
            }
            
        } // ++batch
        
        // --------------------------------------------------------
        // OUTPUT (FREE ENERGIES AND RECONSTRUCTION)
        // --------------------------------------------------------

        if((epoch % free_energy_nepoch) == 0)
        {
            // ----- FREE ENERGIES -----
            std::vector<double> free_energy_X = this->free_energy(
                                                                  X_free_energy
                                                                  );
            
            double avg_free_energy_X = statistics::mean(
                                                        free_energy_X
                                                        );
            
            std::vector<double> free_energy_X_val = this->free_energy(
                                                                      X_val
                                                                      );
            
            double avg_free_energy_X_val = statistics::mean(
                                                            free_energy_X_val
                                                            );
            
            // [epoch #] [p(v_val)/p(v_tr)] [F(v_tr)] [F(v_val)] [dF(v_tr)/depoch] [dF(v_val)/depoch]
            ofs_free_energy << epoch << "   " << std::exp(avg_free_energy_X - avg_free_energy_X_val) << "   " << avg_free_energy_X << "   " << avg_free_energy_X_val << '\n';
            
            // ----- RECONSTRUCTION -----
            double sse = 0.;
            
            for(auto i = 0; i < X_free_energy.size(0); ++i)
            {
                Vector<double> x_reconstruction;
                Vector<double> x_hidden;
                std::tie(
                         x_reconstruction,
                         x_hidden
                         ) = this->reconstruct(
                                               X_free_energy.row(i)
                                               );
                
                Vector<double> diff = x_reconstruction - X_free_energy.row(i);
                sse += dot_product(
                                   diff,
                                   diff
                                   );
            }
            
            sse /= X_free_energy.size(0);
            
            double sse_val = 0.;
            
            for(auto i = 0; i < X_val.size(0); ++i)
            {
                Vector<double> x_reconstruction;
                Vector<double> x_hidden;
                std::tie(
                         x_reconstruction,
                         x_hidden
                         ) = this->reconstruct(
                                               X_val.row(i)
                                               );
                
                Vector<double> diff = x_reconstruction - X_val.row(i);
                sse_val += dot_product(
                                       diff,
                                       diff
                                       );
            }
            
            sse_val /= X_val.size(0);
            
            ofs_reconstruction << epoch << "   " << (sse_val/sse) << "   " << sse << "   " << sse_val << '\n';
            
        }

        // --------------------------------------------------------
        // CONVERGENCE
        // --------------------------------------------------------
        
        double max_w = std::numeric_limits<double>::min();
        
        for(auto i = 0; i < W.size(0); ++i)
        {
            for(auto j = 0; j < W.size(1); ++j)
            {
                double w = std::fabs(W(i,j));
                
                if(w > max_w)
                {
                    max_w = w;
                }
            }
        }
        
        double max_inc_max_w = max_w_inc/max_w;
        
        // CONVERGENCE CRITERION 1: maximum weight gradient (to weight) is below threshold
        if(max_inc_max_w < convg_criterion_max_inc_max_w)
        {
            std::cout << "breaking on condition (max_inc_max_w < convg_criterion_max_inc_max_w)" << '\n';
            break;
        }
        
        if(max_inc_max_w < convg_criterion_min_max_inc_max_w)
        {
            convg_criterion_min_max_inc_max_w = max_inc_max_w;
            
            nno_improvement = 0;
        }
        else
        {
            ++nno_improvement;
            
            // CONVERGENCE CRITERION 2: gradient (to weight) is not reducing too many successive times in a row
            if(nno_improvement > convg_criterion_nno_improvement)
            {
                std::cout << "breaking on condition (nno_improvement > convg_criterion_nno_improvement)" << '\n';
                break;
            }
        }
        
        // --------------------------------------------------------
        // UPDATE MOMENTUM AND CHAIN LENGTHS
        // --------------------------------------------------------
        
        momentum = one_m_momentum_rate*momentum + momentum_rate*momentum_end;
        
        // NOTE: the following leads to jumps in K ... not sure whether that is of any problem
        K        = static_cast<int>(std::round(one_m_K_rate*K + K_rate*K_end));
    } // ++epoch
    
    //=========================================================
    // CLEANUP
    //=========================================================

    ofs_free_energy.close();
    ofs_reconstruction.close();
}